import logging
from dataclasses import dataclass
from io import BytesIO
from pathlib import PureWindowsPath
from typing import Any, Dict, Optional, Sequence, Tuple

from impacket.dcerpc.v5 import scmr, srvs, transport
from impacket.dcerpc.v5.rpcrt import DCERPC_v5
from impacket.smbconnection import SessionError, SMBConnection

from common.credentials import Credentials, LMHash, NTHash, Password
from common.types import NetworkPort
from infection_monkey.i_puppet import TargetHost

logger = logging.getLogger(__name__)

# Windows error codes returned by the Service Control Manager that run_service() tolerates:
# the service did not answer the start request in time (expected, since the command run as a
# "service" is not a real service)
ERROR_SERVICE_REQUEST_TIMEOUT = 1053
# a service with the requested name already exists; run_service() opens it instead
ERROR_SERVICE_EXISTS = 1073


@dataclass
class ShareInfo:
    """Stores information about an SMB share"""

    # Share name (from the SHARE_INFO_2 field shi2_netname)
    name: str
    # Path of the share (from shi2_path), presumably local to the remote host — confirm
    path: PureWindowsPath
    # Current number of connections to the share (from shi2_current_uses)
    current_uses: int
    # Maximum number of connections to the share (from shi2_max_uses)
    max_uses: int


def get_plaintext_secret(credentials: Credentials) -> str:
    """
    Extract the raw secret string from a set of credentials

    :param credentials: Credentials whose secret should be revealed
    :return: The password, LM hash or NT hash as plain text, or an empty string if the
             credentials carry no secret of a supported type
    """
    candidate = credentials.secret

    # Checked in order; each entry maps a secret type to the attribute holding its value
    secret_fields = (
        (Password, "password"),
        (LMHash, "lm_hash"),
        (NTHash, "nt_hash"),
    )
    for secret_class, attribute_name in secret_fields:
        if isinstance(candidate, secret_class):
            return getattr(candidate, attribute_name).get_secret_value()

    return ""


class SMBClient:
    """Wraps an SMB connection and provides methods for interacting with it"""

    def __init__(self):
        self._smb_connection: Optional[SMBConnection] = None
        # Credentials as returned by SMBConnection.getCredentials(); reused to authenticate
        # fresh RPC transports in _rpc_connect_to_port()
        self._authenticated_credentials: Any = None
        self._authenticated = False

    def connected(self) -> bool:
        """Return True if an SMB login succeeded and the session was not logged off as guest"""
        return self._authenticated

    def connect_with_user(self, host: TargetHost, credentials: Credentials, timeout: float):
        """
        Connect to target host over SMB

        Port 445 is tried first, then port 139. A login that results in a guest session is
        rejected: the session is logged off and an exception is raised.

        :param host: A target host to which to connect
        :param credentials: Credentials to use when connecting
        :param timeout: SMB connection timeout
        :raise Exception: If connection fails or the login produced a guest session
        """
        self._create_smb_connection(host)
        self._smb_login(credentials)
        self.set_timeout(timeout)
        if self._logout_guest():
            raise Exception("Logged in as guest")

    @property
    def _established_smb_connection(self) -> SMBConnection:
        """
        The current SMB connection

        :raises Exception: If no SMB connection has been created yet
        """
        if self._smb_connection is None:
            raise Exception("SMB connection not established")
        return self._smb_connection

    def _create_smb_connection(self, host: TargetHost):
        """
        Connect to host over SMB, trying port 445 first and falling back to port 139

        :raises Exception: If a connection could not be created on either port
        """
        try:
            # preferredDialect should be kept as None to choose the correct SMB version
            # For more context check https://github.com/guardicore/monkey/issues/3577
            self._smb_connection = SMBConnection(
                str(host.ip), str(host.ip), sess_port=445, preferredDialect=None
            )
            return
        except Exception as err:
            # Catch broadly: an unreachable port can surface from impacket as a socket or
            # NetBIOS error rather than a SessionError, and the port 139 fallback must still run
            logger.debug(f"Failed to create SMB connection to {host.ip} on port 445: {err}")

        try:
            # "*SMBSERVER" and port 139 is a special case. See doc for SMBConnection
            self._smb_connection = SMBConnection("*SMBSERVER", str(host.ip), sess_port=139)
            return
        except Exception as err:
            logger.debug(f"Failed to create SMB connection to {host.ip} on port 139: {err}")

        raise Exception(f"Failed to create SMB connection to {host.ip}")

    def _smb_login(self, credentials: Credentials):
        """
        Log in over the established SMB connection and remember the resulting credentials

        :raises SessionError: If login fails
        """

        self._established_smb_connection.login(
            user=credentials.identity.username,
            domain="",
            **self._build_args_for_secrets(credentials),
        )
        self._log_smb_dialect()
        self._authenticated = True
        self._authenticated_credentials = self._established_smb_connection.getCredentials()

    def _log_smb_dialect(self):
        """Log the negotiated SMB dialect; failures are only logged, never raised"""
        try:
            smb_dialect = self._established_smb_connection.getDialect()
            logger.debug(f"SMB dialect is: {smb_dialect}")
        except Exception as err:
            logger.debug(f"Exception occured retrieving SMB dialect: {err}")

    @staticmethod
    def _build_args_for_secrets(credentials: Credentials) -> Dict[str, str]:
        """
        Build the secret-related keyword arguments for SMBConnection.login()

        :return: {"password": ""} plus one of "password"/"lmhash"/"nthash" set to the plain
                 text secret, depending on the secret type
        """
        args = {"password": ""}

        if isinstance(credentials.secret, Password):
            secret_type = "password"
        elif isinstance(credentials.secret, LMHash):
            secret_type = "lmhash"
        elif isinstance(credentials.secret, NTHash):
            secret_type = "nthash"
        else:
            return args

        args.update({secret_type: get_plaintext_secret(credentials)})
        return args

    def _logout_guest(self) -> bool:
        """
        Log off the session if it is a guest session

        :return: True if the session was a guest session (and was logged off), else False
        :raises SessionError: If logoff fails
        """
        smb_connection = self._established_smb_connection
        if smb_connection.isGuestSession() > 0:
            # A guest session is not a usable authenticated session, even if logoff fails
            self._authenticated = False
            smb_connection.logoff()
            return True
        return False

    def connect_to_share(self, share_name: str):
        """
        Connects to a share over an active connection

        :param share_name: Name of the SMB share to connect to
        :raises SessionError: If an error occurred while connecting to share
        """
        self._established_smb_connection.connectTree(share_name)

    def query_shared_resources(self) -> Tuple[ShareInfo, ...]:
        """
        Get available network shares

        :return: A tuple of shares information, or an empty tuple if the query fails
        """
        try:
            shares = self._execute_rpc_call(srvs.hNetrShareEnum, 2)
            shares = shares["InfoStruct"]["ShareInfo"]["Level2"]["Buffer"]
            return tuple(SMBClient._impacket_dict_to_share_info(share) for share in shares)
        except Exception as err:
            logger.debug(f"Failed to query shared resources: {err}")
            return ()

    @staticmethod
    def _impacket_dict_to_share_info(share_info_dict: Dict[str, Any]) -> ShareInfo:
        """Convert one level-2 share entry from hNetrShareEnum into a ShareInfo"""
        return ShareInfo(
            # Strip trailing NUL terminators and spaces from the wire strings
            share_info_dict["shi2_netname"].strip("\0 "),
            PureWindowsPath(share_info_dict["shi2_path"].strip("\0 ")),
            share_info_dict["shi2_current_uses"],
            share_info_dict["shi2_max_uses"],
        )

    def _execute_rpc_call(self, rpc_func, *args) -> Any:
        """
        Executes an RPC call using DCE/RPC transport protocol

        The call goes over the established SMB connection to the \\srvsvc pipe, bound to the
        SRVS interface, so rpc_func must be an impacket srvs helper.

        :param rpc_func: Helpers' RPC function, called as rpc_func(rpc, *args)
        :param args: Extra arguments passed to rpc_func
        :return: Whatever rpc_func returns
        :raises SessionError: If an error occurs while executing an RPC call
        """
        smb_connection = self._established_smb_connection
        rpc_transport = transport.SMBTransport(
            smb_connection.getRemoteHost(),
            filename=r"\srvsvc",
            # Passed by keyword: the second positional parameter of SMBTransport is dstport
            remote_host=smb_connection.getRemoteHost(),
            smb_connection=smb_connection,
        )

        rpc = SMBClient._dce_rpc_connect(rpc_transport)
        rpc.bind(srvs.MSRPC_UUID_SRVS)

        return rpc_func(rpc, *args)

    def run_service(
        self,
        service_name: str,
        command: str,
        host: TargetHost,
        ports_to_try: Sequence[NetworkPort],
        timeout: float,
    ):
        """
        Run a command as a service on the remote host.

        The service is created (or opened if it already exists), started and then deleted.
        All service control manager handles are closed afterwards.

        :param service_name: Name to give the service to run
        :param command: Command to be run
        :param host: Target host on which to run the service
        :param ports_to_try: A list of network ports
        :param timeout: Timeout to use for the RPC connection
        :raises Exception: If an error occurred while connecting over SMB or starting the service
        """
        rpc = self._rpc_connect(host, ports_to_try, timeout)
        rpc.bind(scmr.MSRPC_UUID_SCMR)
        sc_handle = scmr.hROpenSCManagerW(rpc)["lpScHandle"]

        try:
            service_handle = SMBClient._create_or_open_service(
                rpc, sc_handle, service_name, command
            )
            try:
                SMBClient._start_service(rpc, service_handle)
            finally:
                # Always remove the service, and release its handle even if removal fails
                try:
                    scmr.hRDeleteService(rpc, service_handle)
                finally:
                    scmr.hRCloseServiceHandle(rpc, service_handle)
        finally:
            SMBClient._close_sc_manager_handle(rpc, sc_handle)

    @staticmethod
    def _create_or_open_service(rpc: DCERPC_v5, sc_handle: Any, service_name: str, command: str):
        """
        Create a service running `command`, or open it if a service with that name exists

        :return: The service handle
        :raises DCERPCSessionError: If creating (for a reason other than "exists") or opening fails
        """
        try:
            resp = scmr.hRCreateServiceW(
                rpc,
                sc_handle,
                service_name,
                service_name,
                lpBinaryPathName=command,
            )
        except scmr.DCERPCSessionError as err:
            if err.error_code != ERROR_SERVICE_EXISTS:
                raise
            logger.debug(f"Service '{service_name}' already exists, trying to start it")
            resp = scmr.hROpenServiceW(rpc, sc_handle, service_name)

        return resp["lpServiceHandle"]

    @staticmethod
    def _start_service(rpc: DCERPC_v5, service_handle: Any):
        """
        Start the service, tolerating the expected start-request timeout

        :raises Exception: If starting fails with any other error
        """
        try:
            scmr.hRStartServiceW(rpc, service_handle)
        except scmr.DCERPCSessionError as err:
            # Since we're abusing the Windows SCM, we should expect ERROR_SERVICE_REQUEST_TIMEOUT
            # because we're not running a real service, which would call
            # StartServiceCtrlDispatcher() and prevent this error
            if err.error_code != ERROR_SERVICE_REQUEST_TIMEOUT:
                raise Exception("Failed to start the service") from err

    @staticmethod
    def _close_sc_manager_handle(rpc: DCERPC_v5, sc_handle: Any):
        """Close the service control manager handle; failures are logged, not raised"""
        try:
            scmr.hRCloseServiceHandle(rpc, sc_handle)
        except Exception as err:
            logger.debug(f"Failed to close the service control manager handle: {err}")

    def _rpc_connect(
        self,
        host: TargetHost,
        ports: Sequence[NetworkPort],
        timeout: float,
    ) -> DCERPC_v5:
        """
        Connects to the remote host and returns the RPC connection

        The existing SMB connection is tried first, then a fresh connection on each port.

        :raises Exception: If no RPC connection could be established
        """

        # Try to use the existing SMB connection
        try:
            smb_connection = self._established_smb_connection
            smb_transport = transport.SMBTransport(
                smb_connection.getRemoteName(),
                filename="\\svcctl",
                smb_connection=smb_connection,
            )
            return SMBClient._dce_rpc_connect(smb_transport)
        except Exception as err:
            logger.debug(f"Failed to use existing SMB connection for RPC: {err}")

        for port in ports:
            try:
                return self._rpc_connect_to_port(host, port, timeout)
            except Exception as err:
                logger.debug(f"Failed to create RPC connection on port {port}: {err}")
        raise Exception("Failed to establish an RPC connection over SMB")

    def _rpc_connect_to_port(
        self, host: TargetHost, port: NetworkPort, timeout: float
    ) -> DCERPC_v5:
        """
        Connects to the remote host over the specified port and returns the RPC connection.

        The transport authenticates with the credentials saved by _smb_login().

        :raises Exception: If connection fails
        """
        rpc_transport = transport.DCERPCTransportFactory(f"ncacn_np:{host.ip}[\\pipe\\svcctl]")
        rpc_transport.set_connect_timeout(timeout)
        rpc_transport.set_dport(int(port))
        rpc_transport.setRemoteHost(str(host.ip))
        rpc_transport.set_credentials(*self._authenticated_credentials)
        rpc_transport.set_kerberos(False)

        rpc = SMBClient._dce_rpc_connect(rpc_transport)
        smb = rpc_transport.get_smb_connection()
        smb.setTimeout(timeout)
        return rpc

    @staticmethod
    def _dce_rpc_connect(rpc_transport) -> DCERPC_v5:
        """
        Establishes a DCE/RPC connection over a given transport stream
        :return: A DCE/RPC connection
        :raises Exception: If an error occurred while connecting to the remote host
        """
        rpc = rpc_transport.get_dce_rpc()
        rpc.connect()
        return rpc

    def send_file(self, share_name: str, path_name: PureWindowsPath, file: bytes):
        """
        Send a file to the remote host

        :param share_name: A network share name
        :param path_name: Destination path of the file within the share
        :param file: Contents of the file to copy to the remote host
        :raises Exception: If an error occurred while sending the file
        """
        file_io = BytesIO(file)
        self._established_smb_connection.putFile(share_name, str(path_name), file_io.read)

    def set_timeout(self, timeout: float):
        """
        Set the connection timeout

        :param timeout: Connection timeout, in seconds
        :raises Exception: If an error occurs
        """
        self._established_smb_connection.setTimeout(timeout)
